import json
import numpy as np
from pathlib import Path

def check_gradient_consistency(log_dir, rtol=1e-4, pattern="rank*.json"):
    """Verify that per-parameter gradient-norm statistics agree across ranks.

    Every file in ``log_dir`` matching ``pattern`` is loaded as JSON and is
    expected to hold a ``"params"`` mapping of parameter name to a dict with a
    ``"norm"`` entry. The first file in sorted path order is the reference.
    Every other file must contain exactly the same parameter names. Each norm
    must also satisfy ``np.isclose(ref_norm, curr_norm, rtol=rtol)``.

    Args:
        log_dir: Directory containing the per-rank JSON dumps.
        rtol: Relative tolerance forwarded to ``np.isclose``.
        pattern: Glob pattern (relative to ``log_dir``) selecting rank files.

    Raises:
        FileNotFoundError: If no file in ``log_dir`` matches ``pattern``.
        AssertionError: If a rank's parameter set, or any parameter's norm,
            differs from the reference.
    """
    # Path.glob yields entries in arbitrary filesystem order; sort so the
    # reference file and the comparison order are deterministic.
    files = sorted(Path(log_dir).glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files matching {pattern!r} found in {log_dir}")

    def _load(path):
        # Context manager closes each handle promptly instead of leaking it.
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)

    reference_file, *other_files = files
    ref_params = _load(reference_file)["params"]

    # Explicit `raise AssertionError` (instead of `assert`) keeps the check
    # active under `python -O` while preserving the exception type.
    for f in other_files:
        curr_params = _load(f)["params"]

        missing = ref_params.keys() - curr_params.keys()
        extra = curr_params.keys() - ref_params.keys()
        if missing or extra:
            raise AssertionError(
                f"Parameter set mismatch in {f.name} vs {reference_file.name}: "
                f"missing={sorted(missing)}, extra={sorted(extra)}"
            )

        for param, stats in ref_params.items():
            ref_norm = stats["norm"]
            curr_norm = curr_params[param]["norm"]
            if not np.isclose(ref_norm, curr_norm, rtol=rtol):
                raise AssertionError(
                    f"Gradient mismatch on {param}: {ref_norm} vs {curr_norm} "
                    f"({reference_file.name} vs {f.name})"
                )

    print("✅ All gradient stats are consistent across ranks!")

if __name__ == "__main__":
    # Script entry point: validate the gradient dumps in the default log dir.
    check_gradient_consistency(log_dir="./grad_logs")
